package main

// maxDepth returns the number of nodes on the longest path from root down to
// a leaf. An empty tree (nil root) has depth 0 and a lone root has depth 1.
//
// Every node is visited exactly once, so the cost is O(n) for n nodes. The
// recursion depth equals the tree height.
func maxDepth(root *TreeNode) int {
	if root == nil {
		return 0
	}
	// Count this node, then take the deeper of the two subtrees.
	return 1 + max(maxDepth(root.Left), maxDepth(root.Right))
}

// max returns the larger of i and j; when they are equal it returns j
// (the values are identical, so the choice is unobservable for ints).
func max(i, j int) int {
	if i > j {
		return i
	}
	return j
}

// main builds a small sample tree and prints its maximum depth using the
// builtin println.
func main() {
	// Sample tree:
	//
	//	    3
	//	   / \
	//	  9   20
	//	     /  \
	//	    15   7
	root := &TreeNode{
		Val:  3,
		Left: &TreeNode{Val: 9},
		Right: &TreeNode{
			Val:   20,
			Left:  &TreeNode{Val: 15},
			Right: &TreeNode{Val: 7},
		},
	}
	println(maxDepth(root))
}
